# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import weakref
from typing import Any, Iterator, Optional, cast

from torch import nn


class RemovableHookHandle:
    """
    Handle that allows a single registered hook to be removed later.

    Hooks live in a two-level structure: an outer nn.ModuleDict maps a key to an inner
    nn.ModuleDict, and the inner dict maps a hook id to the hook module. The handle keeps
    only a weak reference to the outer dict, so holding a handle does not keep the
    storage alive.

    :param storage: The outer nn.ModuleDict containing the hooks (stored as a weak reference).
    :param key: The key in the storage dict corresponding to a set of hooks.
    :param id: A unique identifier for the specific hook within the set.
    """

    def __init__(self, storage: nn.ModuleDict, key: str, id: str) -> None:
        self.key = key
        self.id = id
        self.storage_ref = weakref.ref(storage)

    def remove(self) -> None:
        """
        Remove the hook referenced by this handle, if it is still registered.

        The call is a no-op when the storage has been garbage collected, when the key
        is no longer present, or when the hook id has already been removed. When the
        removed hook was the last one under its key, the key itself is deleted as well.

        :raises TypeError: If the entry stored under the key is not an nn.ModuleDict.
        """
        container: Optional[nn.ModuleDict] = self.storage_ref()
        if container is None:
            # The storage has been garbage collected, nothing to clean up.
            return
        if self.key not in container:
            # The whole set of hooks for this key is already gone.
            return

        bucket = container[self.key]
        if not isinstance(bucket, nn.ModuleDict):
            msg = f"Expected nn.ModuleDict for key={self.key}, got {type(bucket)}"
            raise TypeError(msg)

        if self.id in bucket:
            del bucket[self.id]
            # Drop the key once its last hook has been removed.
            if len(bucket) == 0:
                del container[self.key]


class HookStorage(nn.Module):
    """
    A module for storing and executing hooks.

    Hooks are kept in two nn.ModuleDict containers (pre and post). Each container maps
    a key generated from the operation name and port id to a nested nn.ModuleDict,
    which maps a string hook id to the hook module itself.

    :param pre_hooks: An instance of nn.ModuleDict for storing pre-hooks.
    :param post_hooks: An instance of nn.ModuleDict for storing post-hooks.
    """

    def __init__(self) -> None:
        """
        Initialize an empty HookStorage.
        """
        super().__init__()
        self.pre_hooks: nn.ModuleDict = nn.ModuleDict()
        self.post_hooks: nn.ModuleDict = nn.ModuleDict()

    @staticmethod
    def _generate_key(op_name: str, port_id: int) -> str:
        """
        Return the module dict key generated from an operation name and a port id.

        Dots in the operation name are replaced with colons, so the resulting key
        contains no dots coming from the operation name.

        :param op_name: The operation name the hook is associated with.
        :param port_id: The port ID the hook is associated with.
        :return: Generated key for the module dict, in the form '<op_name>__<port_id>'.
        """
        op_name = op_name.replace(".", ":")
        return f"{op_name}__{port_id}"

    @staticmethod
    def _get_next_hook_id(hooks_dict: nn.ModuleDict) -> str:
        """
        Determines the next available hook ID by finding the highest existing hook ID and incrementing it.

        :param hooks_dict: A dictionary containing existing hooks, keyed by integer-like strings.
        :return: The next available hook ID as a string. Starts from '0' if no hooks exist.
        """
        if not hooks_dict:
            return "0"
        # Use max + 1 rather than the length, so ids of removed hooks are not reused
        # while a hook with a higher id is still registered.
        return str(max(int(k) for k in hooks_dict) + 1)

    @classmethod
    def _insert_hook(
        cls, storage_dict: nn.ModuleDict, op_name: str, port_id: int, hook: nn.Module
    ) -> RemovableHookHandle:
        """
        Inserts a hook into the storage under the appropriate key.

        :param storage_dict: The storage of hook.
        :param op_name: The operation name the hook is associated with.
        :param port_id: The port ID the hook is associated with.
        :param hook: The hook module to be stored.
        :return: A handle that can be used to remove the hook later.
        :raises TypeError: If the entry already stored under the key is not an nn.ModuleDict.
        """
        hook_key = cls._generate_key(op_name, port_id)
        if hook_key not in storage_dict:
            storage_dict[hook_key] = nn.ModuleDict()

        hooks_dict = storage_dict[hook_key]
        if not isinstance(hooks_dict, nn.ModuleDict):
            msg = f"Expected nn.ModuleDict for key={hook_key}, got {type(hooks_dict)}"
            raise TypeError(msg)

        hook_id = cls._get_next_hook_id(hooks_dict)
        hooks_dict[hook_id] = hook

        return RemovableHookHandle(storage_dict, hook_key, hook_id)

    def register_pre_function_hook(self, op_name: str, port_id: int, hook: nn.Module) -> RemovableHookHandle:
        """
        Registers a pre-function hook to be executed before a specific operation.

        :param op_name: The operation name the hook is associated with.
        :param port_id: The port ID the hook is associated with.
        :param hook: The pre-function hook to be stored.
        :return: A handle for removing the registered hook.
        """
        return self._insert_hook(self.pre_hooks, op_name, port_id, hook)

    def register_post_function_hook(self, op_name: str, port_id: int, hook: nn.Module) -> RemovableHookHandle:
        """
        Registers a post-function hook to be executed after a specific operation.

        :param op_name: The operation name the hook is associated with.
        :param port_id: The port ID the hook is associated with.
        :param hook: The post-function hook to be stored.
        :return: A handle for removing the registered hook.
        """
        return self._insert_hook(self.post_hooks, op_name, port_id, hook)

    @classmethod
    def _execute_hooks(cls, storage_dict: nn.ModuleDict, op_name: str, port_id: int, value: Any) -> Any:
        """
        Executes all hooks of a given type (pre or post) for a specific operation and port,
        passing and potentially modifying a value.

        Hooks are applied one after another in the iteration order of the nested
        module dict; the output of each hook is the input of the next one.

        :param storage_dict: The storage of hook.
        :param op_name: The operation name the hooks are associated with.
        :param port_id: The port ID the hooks are associated with.
        :param value: The input value to be passed through the hooks.
        :return: The modified value after all hooks have been applied, or the
            unchanged value if no hooks are registered for the operation and port.
        """
        hook_key = cls._generate_key(op_name, port_id)
        if hook_key not in storage_dict:
            return value
        hooks_dict = cast(nn.ModuleDict, storage_dict[hook_key])
        for hook in hooks_dict.values():
            value = hook(value)
        return value

    def execute_pre_function_hooks(self, op_name: str, port_id: int, value: Any) -> Any:
        """
        Executes all pre-function hooks for a given operation and port.

        :param op_name: The operation name the hooks are associated with.
        :param port_id: The port ID the hooks are associated with.
        :param value: The input value to be passed through the pre-function hooks.
        :return: The value after all pre-function hooks have been executed.
        """
        return self._execute_hooks(self.pre_hooks, op_name, port_id, value)

    def execute_post_function_hooks(self, op_name: str, port_id: int, value: Any) -> Any:
        """
        Executes all post-function hooks for a given operation and port.

        :param op_name: The operation name the hooks are associated with.
        :param port_id: The port ID the hooks are associated with.
        :param value: The input value to be passed through the post-function hooks.
        :return: The value after all post-function hooks have been executed.
        """
        # The operation name is normalized inside _generate_key, same as for pre-hooks.
        return self._execute_hooks(self.post_hooks, op_name, port_id, value)

    def named_hooks(self, prefix: str = "", remove_duplicate: bool = True) -> Iterator[tuple[str, nn.Module]]:
        """
        Retrieve named hook modules from the model.

        :param prefix: Prefix to prepend (dot-separated) to each hook name. Default is "".
        :param remove_duplicate: Whether to remove duplicate modules. Default is True.
        :return: Name and module pairs.
        """
        for name, module in self.named_modules(remove_duplicate=remove_duplicate):
            # Expected depths of target hook module is 2
            # <2 - ModuleDicts in HookStorage, >2 - submodules of hooks
            if name.count(".") == 2:
                yield f"{prefix}.{name}" if prefix else name, cast(nn.Module, module)

    def delete_hook(self, hook_name: str) -> None:
        """
        Deletes a hook from the storage corresponding to the specified name.

        If the deleted hook was the last one registered under its key,
        the key is removed from the storage as well.

        :param hook_name: The name of the hook to be deleted, which includes
            information about the hook type, operation name, and port ID.
            The last dot-separated part of the name is used as the hook id.

        :raises ValueError: If the hook name cannot be decoded.
        :raises ValueError: If no hook is found for the given hook name.
        :raises ValueError: If the specified hook instance cannot be located for the provided hook name.
        """
        hook_type, op_name, port_id = decode_hook_name(hook_name)
        storage_dict = getattr(self, hook_type)
        hook_key = self._generate_key(op_name, port_id)
        if hook_key not in storage_dict:
            msg = f"No hook was found for a given hook name={hook_name}"
            raise ValueError(msg)

        hook_id = hook_name.split(".")[-1]
        if hook_id not in storage_dict[hook_key]:
            msg = f"No hook was found for a given hook name={hook_name} and hook id={hook_id}"
            raise ValueError(msg)

        del storage_dict[hook_key][hook_id]
        if not storage_dict[hook_key]:
            del storage_dict[hook_key]


def decode_hook_name(hook_name: str) -> tuple[str, str, int]:
    """
    Decodes a name of the hook to extract the hook type, operation name and port ID.

    Only the last three dot-separated parts of the name are interpreted, as
    ``<hook_type>.<op_name>__<port_id>.<hook_id>``; any leading parts are ignored.
    Colons in the encoded operation name are converted back to dots.

    :param hook_name: The name of the hook that returns from HookStorage().named_hooks().
    :return: Hook type, operation name and port id.
    :raises ValueError: If the name has fewer than 3 parts, the hook type is neither
        'pre_hooks' nor 'post_hooks', the key does not contain '__', or the port id
        is not an integer.
    """
    splitted = hook_name.split(".")
    if len(splitted) < 3:
        msg = f"Invalid hook name, name should contain at least 3 parts, got {hook_name}"
        raise ValueError(msg)
    hook_type = splitted[-3]
    if hook_type not in ("pre_hooks", "post_hooks"):
        msg = f"Invalid hook name, name should contain 'pre_hooks' or 'post_hooks', got {hook_name}"
        raise ValueError(msg)
    hook_key = splitted[-2]
    if "__" not in hook_key:
        msg = f"Invalid hook name, hook_key expect op_name and port_id, got {hook_key}"
        raise ValueError(msg)
    # Split on the last '__' so operation names that contain '__' themselves stay intact.
    op_name, port_id = hook_key.rsplit("__", 1)
    try:
        port = int(port_id)
    except ValueError as err:
        # Report the offending name instead of the bare int() conversion error.
        msg = f"Invalid hook name, port_id should be an integer, got {port_id!r} in {hook_name}"
        raise ValueError(msg) from err
    return hook_type, op_name.replace(":", "."), port
